#  Copyright 2025 Collate
#  Licensed under the Collate Community License, Version 1.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#  https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
"""
Databricks Unity Catalog Lineage Source Module
"""
import traceback
from typing import Iterable, Optional

from metadata.generated.schema.api.lineage.addLineage import AddLineageRequest
from metadata.generated.schema.entity.data.container import ContainerDataModel
from metadata.generated.schema.entity.data.database import Database
from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema
from metadata.generated.schema.entity.data.table import Table
from metadata.generated.schema.entity.services.connections.database.unityCatalogConnection import (
    UnityCatalogConnection,
)
from metadata.generated.schema.metadataIngestion.workflow import (
    Source as WorkflowSource,
)
from metadata.generated.schema.type.entityLineage import (
    ColumnLineage,
    EntitiesEdge,
    LineageDetails,
)
from metadata.generated.schema.type.entityLineage import Source as LineageSource
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.ingestion.api.models import Either
from metadata.ingestion.api.steps import InvalidSourceException, Source
from metadata.ingestion.lineage.sql_lineage import get_column_fqn
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.connections import test_connection_common
from metadata.ingestion.source.database.unitycatalog.client import UnityCatalogClient
from metadata.ingestion.source.database.unitycatalog.connection import get_connection
from metadata.ingestion.source.database.unitycatalog.models import LineageTableStreams
from metadata.utils import fqn
from metadata.utils.filters import filter_by_database, filter_by_schema, filter_by_table
from metadata.utils.helpers import retry_with_docker_host
from metadata.utils.logger import ingestion_logger

logger = ingestion_logger()


class UnitycatalogLineageSource(Source):
    """
    Lineage Unity Catalog Source

    Walks the tables already ingested for the configured service (honouring the
    catalog / schema / table filter patterns), asks the Unity Catalog client for
    each table's lineage and emits table -> table and container -> table
    lineage edges, with column-level lineage whenever it can be resolved.
    """

    @retry_with_docker_host()
    def __init__(
        self,
        config: WorkflowSource,
        metadata: OpenMetadata,
    ):
        super().__init__()
        self.config = config
        self.metadata = metadata
        self.service_connection = self.config.serviceConnection.root.config
        self.source_config = self.config.sourceConfig.config
        self.client = UnityCatalogClient(self.service_connection)
        self.connection_obj = get_connection(self.service_connection)
        self.test_connection()

    def close(self):
        """
        By default, there is nothing to close
        """

    def prepare(self):
        """
        By default, there's nothing to prepare
        """

    @classmethod
    def create(
        cls, config_dict, metadata: OpenMetadata, pipeline_name: Optional[str] = None
    ):
        """
        Create a class instance from a raw workflow config dict.

        Raises:
            InvalidSourceException: if the service connection is not a
                UnityCatalogConnection.
        """
        config: WorkflowSource = WorkflowSource.model_validate(config_dict)
        connection: UnityCatalogConnection = config.serviceConnection.root.config
        if not isinstance(connection, UnityCatalogConnection):
            raise InvalidSourceException(
                f"Expected UnityCatalogConnection, but got {connection}"
            )
        return cls(config, metadata)

    @staticmethod
    def _get_databricks_table_fqn(table_info) -> str:
        """Return the ``catalog.schema.table`` name of a Unity Catalog ``tableInfo``."""
        return f"{table_info.catalog_name}.{table_info.schema_name}.{table_info.name}"

    def _get_table_entity(self, table_info) -> Optional[Table]:
        """
        Resolve a Unity Catalog ``tableInfo`` to the OpenMetadata Table of this service.

        Returns:
            The Table entity, or whatever ``get_by_name`` returns when it is not
            found (falsy in that case, as the callers rely on).
        """
        table_fqn = fqn.build(
            metadata=self.metadata,
            entity_type=Table,
            database_name=table_info.catalog_name,
            schema_name=table_info.schema_name,
            table_name=table_info.name,
            service_name=self.config.serviceName,
        )
        return self.metadata.get_by_name(entity=Table, fqn=table_fqn)

    def _get_data_model_column_fqn(
        self, data_model_entity: ContainerDataModel, column: str
    ) -> Optional[str]:
        """
        Find the FQN of the container data-model column matching ``column``.

        Matching is case-insensitive.

        Returns:
            The matching column FQN, or None if there is no data model or no
            column matches.
        """
        if not data_model_entity:
            return None
        target = column.lower()
        for entity_column in data_model_entity.columns or []:
            # displayName is optional: fall back to the column name so a column
            # without a display name doesn't raise and abort the whole match.
            label = entity_column.displayName or entity_column.name.root
            if label and label.lower() == target:
                return entity_column.fullyQualifiedName.root
        return None

    def _get_container_column_lineage(
        self,
        data_model_entity: ContainerDataModel,
        table_entity: Table,
        is_upstream: bool = True,
    ) -> Optional[LineageDetails]:
        """
        Build column lineage between a container data model and a table,
        pairing columns by name.

        Args:
            data_model_entity: data model of the container.
            table_entity: the Unity Catalog table.
            is_upstream: True when the container feeds the table
                (container column -> table column); False when the table feeds
                the container (table column -> container column). The direction
                must match the edge the details are attached to.

        Returns:
            LineageDetails with ExternalTableLineage source, or None when no
            columns match or an error occurs (errors are logged at debug level).
        """
        try:
            column_lineage = []
            for column in table_entity.columns:
                container_column = self._get_data_model_column_fqn(
                    data_model_entity=data_model_entity, column=column.name.root
                )
                table_column = column.fullyQualifiedName.root
                if not (container_column and table_column):
                    continue
                if is_upstream:
                    column_lineage.append(
                        ColumnLineage(
                            fromColumns=[container_column], toColumn=table_column
                        )
                    )
                else:
                    column_lineage.append(
                        ColumnLineage(
                            fromColumns=[table_column], toColumn=container_column
                        )
                    )
            if column_lineage:
                return LineageDetails(
                    columnsLineage=column_lineage,
                    source=LineageSource.ExternalTableLineage,
                )
            return None
        except Exception as exc:
            logger.debug(f"Error computing container column lineage: {exc}")
            logger.debug(traceback.format_exc())
            return None

    def _get_lineage_details(
        self,
        from_table: Table,
        to_table: Table,
        databricks_table_fqn: str,
        column_streams_cache: Optional[dict] = None,
    ) -> Optional[LineageDetails]:
        """
        Build column lineage from ``from_table`` into ``to_table``.

        For each column of ``to_table``, the Unity Catalog column-lineage API is
        queried and every upstream column name that exists in ``from_table`` is
        kept (matched via ``get_column_fqn``).

        Args:
            from_table: upstream OpenMetadata table.
            to_table: downstream OpenMetadata table.
            databricks_table_fqn: ``catalog.schema.table`` name of ``to_table``
                as known to Unity Catalog.
            column_streams_cache: optional dict (column name -> API response)
                shared between calls that target the same ``to_table``, so each
                column is fetched only once.

        Returns:
            LineageDetails with QueryLineage source, or None when nothing matched
            or an error occurred (errors are logged at debug level).
        """
        try:
            col_lineage = []
            for column in to_table.columns:
                column_name = column.name.root
                if column_streams_cache is not None and column_name in column_streams_cache:
                    column_streams = column_streams_cache[column_name]
                else:
                    column_streams = self.client.get_column_lineage(
                        databricks_table_fqn, column_name=column_name
                    )
                    if column_streams_cache is not None:
                        column_streams_cache[column_name] = column_streams

                upstream_cols = (
                    (column_streams.upstream_cols or []) if column_streams else []
                )
                from_columns = []
                for col in upstream_cols:
                    # get_column_fqn lower-cases the name, so skip unnamed columns
                    if not col.name:
                        continue
                    col_fqn = get_column_fqn(from_table, col.name)
                    if col_fqn:
                        from_columns.append(col_fqn)

                if from_columns:
                    col_lineage.append(
                        ColumnLineage(
                            fromColumns=from_columns,
                            toColumn=column.fullyQualifiedName.root,
                        )
                    )
            if col_lineage:
                return LineageDetails(
                    columnsLineage=col_lineage, source=LineageSource.QueryLineage
                )
            return None
        except Exception as exc:
            logger.debug(
                f"Error computing column lineage for {to_table.fullyQualifiedName.root} - {exc}"
            )
            logger.debug(traceback.format_exc())
            return None

    def _handle_external_location_lineage(
        self, file_info, table: Table, is_upstream: bool
    ) -> Iterable[Either[AddLineageRequest]]:
        """
        Emit lineage between ``table`` and the container stored at
        ``file_info.storage_location``.

        The container is looked up by path; only the first search hit is used.
        Errors are logged at debug level and nothing is yielded.

        Args:
            file_info: Unity Catalog file info exposing ``storage_location``.
            table: the Unity Catalog table on the other end of the edge.
            is_upstream: True for container -> table, False for table -> container.
        """
        try:
            if not file_info.storage_location:
                logger.debug("No storage location found in fileInfo")
                return

            storage_location = file_info.storage_location.rstrip("/")
            location_entity = self.metadata.es_search_container_by_path(
                full_path=storage_location, fields="dataModel"
            )
            container = location_entity[0] if location_entity else None
            if not container:
                logger.debug(
                    f"Unable to find container for external location: {storage_location}"
                )
                return

            lineage_details = None
            if container.dataModel:
                # Column lineage must follow the same direction as the edge
                lineage_details = self._get_container_column_lineage(
                    container.dataModel, table, is_upstream=is_upstream
                )

            container_ref = EntityReference(id=container.id, type="container")
            table_ref = EntityReference(id=table.id, type="table")
            if is_upstream:
                from_ref, to_ref = container_ref, table_ref
            else:
                from_ref, to_ref = table_ref, container_ref

            yield Either(
                left=None,
                right=AddLineageRequest(
                    edge=EntitiesEdge(
                        fromEntity=from_ref,
                        toEntity=to_ref,
                        lineageDetails=lineage_details,
                    )
                ),
            )
        except Exception as exc:
            logger.debug(
                f"Error while processing external location lineage for {file_info.storage_location}: {exc}"
            )
            logger.debug(traceback.format_exc())

    def _handle_upstream_table(
        self,
        table_streams: LineageTableStreams,
        table: Table,
        databricks_table_fqn: str,
    ) -> Iterable[Either[AddLineageRequest]]:
        """
        Emit lineage edges from every upstream entity of ``table``.

        Upstream files become container -> table edges. Upstream tables become
        table -> table edges with column lineage. Each upstream entity is
        processed best-effort: failures are logged and skipped.
        """
        # The column lineage of ``table`` is the same whichever upstream is being
        # processed, so fetch each column's lineage once for all upstreams.
        column_streams_cache: dict = {}
        for upstream_entity in table_streams.upstreams or []:
            try:
                if upstream_entity.fileInfo:
                    yield from self._handle_external_location_lineage(
                        upstream_entity.fileInfo, table, is_upstream=True
                    )
                    continue

                upstream_table = upstream_entity.tableInfo
                if not upstream_table or not upstream_table.name:
                    continue

                from_entity = self._get_table_entity(upstream_table)
                if not from_entity:
                    logger.debug(
                        f"Unable to find upstream entity for "
                        f"{self._get_databricks_table_fqn(upstream_table)}"
                        f" -> {databricks_table_fqn}"
                    )
                    continue

                lineage_details = self._get_lineage_details(
                    from_table=from_entity,
                    to_table=table,
                    databricks_table_fqn=databricks_table_fqn,
                    column_streams_cache=column_streams_cache,
                )
                yield Either(
                    left=None,
                    right=AddLineageRequest(
                        edge=EntitiesEdge(
                            toEntity=EntityReference(id=table.id, type="table"),
                            fromEntity=EntityReference(
                                id=from_entity.id, type="table"
                            ),
                            lineageDetails=lineage_details,
                        )
                    ),
                )
            except Exception:
                logger.debug(
                    "Error while processing upstream lineage for "
                    f"{databricks_table_fqn}"
                )
                logger.debug(traceback.format_exc())

    def _handle_downstream_table(
        self,
        table_streams: LineageTableStreams,
        table: Table,
        databricks_table_fqn: str,
    ) -> Iterable[Either[AddLineageRequest]]:
        """
        Emit lineage edges towards every downstream entity of ``table``.

        Downstream files become table -> container edges. Downstream tables
        become table -> table edges with column lineage. Each downstream entity
        is processed best-effort: failures are logged and skipped.
        """
        for downstream_entity in table_streams.downstreams or []:
            try:
                if downstream_entity.fileInfo:
                    yield from self._handle_external_location_lineage(
                        downstream_entity.fileInfo, table, is_upstream=False
                    )
                    continue

                downstream_table = downstream_entity.tableInfo
                if not downstream_table or not downstream_table.name:
                    continue

                downstream_table_fqn = self._get_databricks_table_fqn(
                    downstream_table
                )
                to_entity = self._get_table_entity(downstream_table)
                if not to_entity:
                    logger.debug(
                        f"Unable to find downstream entity for "
                        f"{databricks_table_fqn} -> "
                        f"{downstream_table_fqn}"
                    )
                    continue

                # Column lineage is queried on the downstream table's columns
                lineage_details = self._get_lineage_details(
                    from_table=table,
                    to_table=to_entity,
                    databricks_table_fqn=downstream_table_fqn,
                )
                yield Either(
                    left=None,
                    right=AddLineageRequest(
                        edge=EntitiesEdge(
                            fromEntity=EntityReference(id=table.id, type="table"),
                            toEntity=EntityReference(id=to_entity.id, type="table"),
                            lineageDetails=lineage_details,
                        )
                    ),
                )
            except Exception:
                logger.debug(
                    "Error while processing downstream lineage for "
                    f"{databricks_table_fqn}"
                )
                logger.debug(traceback.format_exc())

    def _get_filtered_tables(self) -> Iterable[Table]:
        """
        Yield every table of the configured service that passes the catalog,
        schema and table filter patterns, recording filtered entities in status.
        """
        for database in self.metadata.list_all_entities(
            entity=Database, params={"service": self.config.serviceName}
        ):
            if filter_by_database(
                self.source_config.databaseFilterPattern, database.name.root
            ):
                self.status.filter(
                    database.fullyQualifiedName.root,
                    "Catalog Filtered Out",
                )
                continue
            for schema in self.metadata.list_all_entities(
                entity=DatabaseSchema,
                params={"database": database.fullyQualifiedName.root},
            ):
                if filter_by_schema(
                    self.source_config.schemaFilterPattern, schema.name.root
                ):
                    self.status.filter(
                        schema.fullyQualifiedName.root,
                        "Schema Filtered Out",
                    )
                    continue
                for table in self.metadata.list_all_entities(
                    entity=Table,
                    params={"databaseSchema": schema.fullyQualifiedName.root},
                ):
                    if filter_by_table(
                        self.source_config.tableFilterPattern, table.name.root
                    ):
                        self.status.filter(
                            table.fullyQualifiedName.root,
                            "Table Filtered Out",
                        )
                        continue
                    yield table

    def _iter(self, *_, **__) -> Iterable[Either[AddLineageRequest]]:
        """
        Fetch lineage from the Unity Catalog lineage API for every ingested,
        non-filtered table of the service and yield the resulting
        AddLineageRequests to the sink.

        A table whose lineage cannot be fetched is logged and skipped, so one
        failing table does not abort the whole run.
        """
        for table in self._get_filtered_tables():
            databricks_table_fqn = f"{table.database.name}.{table.databaseSchema.name}.{table.name.root}"
            logger.debug(f"Processing table: {databricks_table_fqn}")
            try:
                table_streams: LineageTableStreams = self.client.get_table_lineage(
                    databricks_table_fqn
                )
            except Exception as exc:
                logger.warning(
                    f"Unable to fetch lineage for table {databricks_table_fqn}: {exc}"
                )
                logger.debug(traceback.format_exc())
                continue
            if not table_streams:
                continue

            # Process upstream lineage
            yield from self._handle_upstream_table(
                table_streams, table, databricks_table_fqn
            )

            # NOTE: downstream lineage (_handle_downstream_table) is intentionally
            # not processed for now, as it causes slowness. Each table's upstreams
            # already produce the same edges from the other side.

    def test_connection(self) -> None:
        """Run the common connection test for the Unity Catalog service."""
        test_connection_common(
            self.metadata, self.connection_obj, self.service_connection
        )
